nanoDiT

An educational repository showing rectified-flow training of class-conditional DiTs for image generation (~600 LoC). The repository draws inspiration from projects like nanoGPT and nanoVLM, and tries to do something similar for rectified-flow-based image generation. It is meant to be hackable and a friendly starting point for folks wanting to get started in the area.

Repo structure

The core DiT implementation is in model.py, adapted heavily from the original DiT implementation. Training is implemented in train.py. The repo has two dependencies -- torch and torchvision (needed for data loading).

Getting started

The repo implements class-conditional DiT training on the butterflies dataset. Download the dataset:

from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="sayakpaul/butteflies_with_classes", repo_type="dataset", local_dir="butterflies"
)
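
If you prefer the command line, the huggingface_hub CLI can do the same download (assuming a reasonably recent huggingface_hub version is installed):

huggingface-cli download sayakpaul/butteflies_with_classes --repo-type dataset --local-dir butterflies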

Start training:

python train.py

train.py has reasonable defaults set as constants at the top of the file. Feel free to modify them as needed. Running the script as is yields the following loss curve:

(figure: training loss curve)

The visualization of intermediate results at the final step yields:

(figure: grid of generated samples)
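
For orientation, the tunable constants at the top of train.py look roughly like the sketch below. Apart from DEVICE (mentioned further down), the names and all values here are hypothetical -- check the file for the actual ones.

# Illustrative configuration constants; only DEVICE is referenced in this
# README, and every value shown here is a guess, not copied from train.py.
DEVICE = "cuda"           # set to "mps" on Apple Silicon, or "cpu"
RESOLUTION = 64           # the repo trains at 64x64 by default
BATCH_SIZE = 64           # hypothetical
LEARNING_RATE = 1e-4      # hypothetical
NUM_TRAIN_STEPS = 10_000  # hypothetical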

By default, we train at 64x64 resolution to keep the runtime short and the memory requirements low.

If you're running training on MPS, change DEVICE accordingly. The iterative sampling process used for inference is already implemented in the sample_conditional_cfg() function.
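
The sketch below shows the general shape of such a sampler: rectified-flow Euler integration from noise to data with classifier-free guidance. It is not the code of sample_conditional_cfg() itself; the model interface, argument names, and the t=1-is-noise convention are assumptions.

import torch

@torch.no_grad()
def sample_cfg_sketch(model, class_ids, null_id, resolution=64, num_steps=50, guidance_scale=4.0):
    # Start from pure Gaussian noise at t=1 and integrate toward data at t=0.
    b = class_ids.shape[0]
    x = torch.randn(b, 3, resolution, resolution, device=class_ids.device)
    ts = torch.linspace(1.0, 0.0, num_steps + 1, device=class_ids.device)
    null_ids = torch.full_like(class_ids, null_id)  # the "unconditional" label
    for i in range(num_steps):
        t = ts[i].expand(b)
        # Two forward passes: conditional and unconditional velocity predictions.
        v_cond = model(x, t, class_ids)
        v_uncond = model(x, t, null_ids)
        # Classifier-free guidance: extrapolate past the unconditional velocity.
        v = v_uncond + guidance_scale * (v_cond - v_uncond)
        # Euler step along the flow: dx = v * dt (dt is negative here).
        x = x + v * (ts[i + 1] - ts[i])
    return x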

Notes

The repo aims to bring together the essential components while striving to stay minimal and reasonable:

  • Pixel-space instead of latent-space, because the VAE is a separate piece of magic that was purposely left out of this repo.
  • DiT as the base architecture, as it is the core of modern models like Flux.
  • Class-conditional, to show how to embed conditions other than just timesteps.
  • Rectified flows as a simpler alternative to DDPM-style diffusion training (see the sketch after this list).
  • Small enough that you can fiddle with the model and the training setup on a local laptop.
  • Classifier-free guidance for both training and inference, as it is common practice (also shown in the sketch below).
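
To make the last two points concrete, here is a minimal sketch of a single rectified-flow training step with CFG-style label dropping. All names are illustrative, and the interpolation convention (t=0 is data, t=1 is noise) is an assumption; train.py's actual code may differ.

import torch
import torch.nn.functional as F

def rectified_flow_step_sketch(model, x0, class_ids, null_id, drop_prob=0.1):
    b = x0.shape[0]
    x1 = torch.randn_like(x0)            # noise endpoint of the straight path
    t = torch.rand(b, device=x0.device)  # uniform timestep per sample
    t_ = t.view(b, 1, 1, 1)
    xt = (1.0 - t_) * x0 + t_ * x1       # linear interpolation between data and noise
    target = x1 - x0                     # constant velocity along the straight path
    # Randomly replace labels with a "null" class so the model also learns an
    # unconditional prediction -- this is what enables classifier-free guidance.
    drop = torch.rand(b, device=x0.device) < drop_prob
    labels = torch.where(drop, torch.full_like(class_ids, null_id), class_ids)
    v_pred = model(xt, t, labels)        # model predicts the velocity field
    return F.mse_loss(v_pred, target)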
